--- support for 'Decimal' numbers

module frege.data.Dec64 
        inline (positive, Positive.isNaN, Positive.isZero, Positive.coefficient, Positive.exponent,
                Decimal.hashCode, mulu10, relevantBits)
    where


import frege.Prelude hiding (>>, ^, ~)
import Data.Bits

-- Short operator aliases for the 'Data.Bits' functions used throughout this module.
-- The Prelude operators '>>', '^' and '~' are hidden by the import above, so that
-- these names are free to be defined here.
protected (<<)  = shiftL            -- shift left
protected (>>)  = shiftR            -- shift right (NOTE(review): presumably sign-propagating, in contrast to 'ushiftR' - confirm in Data.Bits)
protected (>>>) = ushiftR           -- unsigned shift right
protected (&)   = (.&.)             -- bitwise and
protected (¦)   = (.|.)             -- bitwise or
protected (^)   = (.^.)             -- bitwise exclusive or
protected (~)   = Bits.complement   -- bitwise complement


--- Unsigned division by 10, using only additions, subtractions and bit shifts.
--- The argument is taken as an unsigned 64 bit quantity, hence every right shift
--- is an unsigned one ('>>>'). A sign-propagating shift would smear bit 63 into
--- the partial sums and give wrong results for arguments that have this bit set.
--- For non-negative arguments both kinds of shift yield the very same bits.
divu10 :: Long -> Long
divu10 n = estimate + ((rest+6) >>> 4)
    where
        -- successive approximation of n*0.8 (0.8 is binary 0.110011001100...)
        s1       = (n >>> 1) + (n >>> 2)
        s2       = s1 + (s1 >>> 4)
        s3       = s2 + (s2 >>> 8)
        s4       = s3 + (s3 >>> 16)
        s5       = s4 + (s4 >>> 32)
        -- n*0.8/8 is n/10, possibly a bit too small since every shift truncates
        estimate = s5 >>> 3
        -- what is left of n after taking away estimate*10 (computed as estimate*8 + estimate*2)
        rest     = n - (estimate << 3) - (estimate << 1)



{-
    Reference implementation in C (32 bit version) of the algorithm used by 'remu10' below:

    int remu10(unsigned n) {
        static char table[16] = {0, 1, 2, 2, 3, 3, 4, 5, 5, 6, 7, 7, 8, 8, 9, 0};
        n = (0x19999999*n + (n >> 1) + (n >> 3)) >> 28;
        return table[n];
    }

-}

--- unsigned modulus 10, using only multiplication, addition and bit shifts
--- The topmost 4 bits of a scaled sum select one of 16 remainders,
--- which are stored as the 16 hexadecimal digits of a single 'Long' constant.
remu10 :: Long -> Long
remu10 n = (table >> (index.int<<2)) & 0x0fL
    where
        -- hex digit k (counted from the right) holds the remainder for index k
        table = 0x0988776554332210L
        index = (0x1999999999999999L*n + (n>>>1) + (n>>>3)) >>> 60

--- multiplication by 10, with shift and add: n*8 + n*2
--- this will be inlined, so only use strict variables as argument
mulu10 :: Long → Long
mulu10 !n = (n<<3) + (n<<1)

--- We will do the real work with positive 'Decimal's only and afterwards negate the result if need be
--- This will reduce the cases to be considered and ease overflow detection
data Positive = pure native "long" where
    {-- 
        Interpret the bits of the 'Long' value as a 'Positive'
        This is a no-operation, as 'Positive's *are*  'Long's and every 'Long' value is a valid 'Positive'
        (though it might not be a number, see also 'Positive.nan').
    -} 
    pure native fromBits "(long)" :: Long    → Positive
    {-- 
        Interpret the bits of a 'Positive' as 'Long' value.
        The result will look like this
        
        > mmmm mmmm mmmm mmmm mmmm mmmm mmmm mmmm mmmm mmmm mmmm mmmm mmmm mmmm eeee eeee
        >         54        48        40        32        24        16        8 7        0
        - The 8  e bits 0..7  are the exponent in 2's complement encoding
        - The 56 m bits 8..63 are the unsigned coefficient.
        - an exponent of -128 (0x80) signals not-a-number (NaN)
        - the value of a number that is not NaN is coefficient * 10^exponent
        
        Note that there are usually several encodings for one and the same value.
        1, for example, could be 0x100 or 0xAFF (10E-1)
        Likewise 5 could be represented as 5e0, 50e-1, 500e-2 and so forth.
        
        This is a no-operation, as 'Positive's *are* 'Long's. 
        However, the Frege compiler sees them as totally unrelated types (and rightly so). 
    -}
    pure native toBits   "(long)" :: Positive → Long
    
    --- the smallest possible coefficient is 0
    pure native minCoefficient " 0L" :: Long
    
    --- the largest possible coefficient is 72\_057\_594\_037\_927\_935 (that is 2^56-1, all 56 coefficient bits set)
    pure native maxCoefficient " 72057594037927935L" :: Long
    
    --- the smallest exponent is -127 (0xffffff81 is -127 in 32 bit 2's complement)
    pure native minExponent " 0xffffff81" :: Int
    
    --- the largest exponent is 127
    pure native maxExponent " 127" :: Int
    
    --- The canonical _not a number_ value.
    --- Note that there are 2^56 NaN values, all comparing equal.
    pure native nan " 0xDeadBeef80L" :: Positive
    
    --- The canonical 'Positive' 0
    pure native zero " 0L" :: Positive
    
    --- The canonical 'Positive' 1
    pure native one  " 0x100L" :: Positive
    
    --- tell if a 'Positive' is not a number. This works for all 2^56 NaN values.
    --- Only the 8 exponent bits are looked at: 128 is the bit pattern 0x80.
    isNaN d = (toBits d) & 0xffL == 128
    
    --- tell if a 'Positive' is 0. This works for all 255 0 values.
    --- (A zero coefficient together with the NaN exponent is not a zero.)
    isZero d = not (isNaN d) && d.coefficient == 0
    
    --- extract the coefficient from a 'Positive'
    --- This is simply an unsigned right shift by 8
    coefficient d = toBits d  >>>  8

    --- extract the exponent, sign extended in an 'Int'
    --- The low 8 bits are moved to the top of the 'Int' and shifted back down again.
    exponent d = ((toBits d & 0xffL).int Int.`shiftL` 24) Int.`shiftR` 24         -- sign extend
    
    {-- 
        construct a 'Positive' from a coefficient and an exponent
        
        This will be 'Positive.nan' if the coefficient is not in the range 'minCoefficient' .. 'maxCoefficient'
        or the exponent is not in the range 'minExponent' .. 'maxExponent', *even if* the value could be 
        represented. Thus:
        
        > Positive.pack 72057594037927935 0     == 72057594037927935z
        > Positive.pack 720575940379279350 (-1) == nan
        
        The exponent is cut down to its low 8 bits before it is combined with the shifted coefficient.
    --}
    pack coeff !exp | coeff < minCoefficient = nan
                    | coeff > maxCoefficient = nan
                    | exp   < minExponent    = nan
                    | exp   > maxExponent    = nan
                    | otherwise              = fromBits ((coeff << 8) ¦ (exp .&. 0xFF).long) 

--- Show a 'Positive' as coefficient, followed by "e" and the exponent unless the latter is 0.
--- A trailing "!" marks coefficients that are greater than 'Decimal.maxCoefficient'.
instance Show Positive where
    show d
        | d.isNaN         = "NaN"
        | d.exponent == 0 = digits ++ warn
        | otherwise       = digits ++ "e" ++ show d.exponent ++ warn
        where
            digits = show d.coefficient
            warn | d.coefficient > Decimal.maxCoefficient = "!"
                 | otherwise                              = ""


--- a coefficient cannot be multiplied by 10 if it is greater than 922_337_203_685_477_580,
--- as the product would not fit in a signed 64 bit value anymore
pure native expansionMax " 922337203685477580L" :: Long

--- a number greater than this, multiplied by 10, wouldn't fit anymore in the 56 coefficient bits
--- (7_205_759_403_792_793 * 10 is just below 'Positive.maxCoefficient', which is 72_057_594_037_927_935)
pure native coeffExpansionMax " 7205759403792793L" :: Long

--- get the absolute value of a 'Decimal' as 'Positive' 
--- The sign is dropped by taking 'abs' of the coefficient, the exponent is passed on unchanged.
--- 'Positive.pack' answers 'Positive.nan' when the exponent is not in the range
--- 'Positive.minExponent' .. 'Positive.maxExponent' (presumably this is how a NaN 'Decimal' stays NaN - TODO confirm)
positive ∷ Decimal → Positive
positive d = Positive.pack (abs d.coefficient) d.exponent

--- convert a 'Positive' back to a 'Decimal' with the sign indicated by the first argument
--- A negative first argument gives a negated coefficient, any other value leaves the coefficient as it is.
--- NaN becomes 'Decimal.nan' and every zero becomes 'Decimal.zero'.
--- Otherwise coefficient and exponent are adjusted step by step before 'Decimal.pack' is applied.
--- Dividing the coefficient by 10 rounds to nearest (5 is added beforehand) and may lose precision.
decimal ∷ Int → Positive → Decimal
decimal !sign !p
    | p.isNaN   = Decimal.nan
    | p.isZero  = Decimal.zero
    | otherwise = go sign p.coefficient p.exponent
    where
        go ∷ Int → Long → Int → Decimal
        go !sign !c !x
            -- traceLn ("decimal.go %d %d".format c x) = Decimal.nan 
            -- coefficient beyond the limit for this sign and the exponent may still grow:
            -- drop the last digit (rounded) and raise the exponent
            | if sign <0 then c > abs Decimal.minCoefficient else c > Decimal.maxCoefficient, 
              x < Decimal.maxExponent = go sign (divu10 (c+5)) (x+1)
            -- exponent too big, but the coefficient can take another trailing 0
            | x > Decimal.maxExponent, c <= coeffExpansionMax = go sign (mulu10 c) (x-1)
            -- exponent too small: drop the last digit (rounded) and raise the exponent
            | x < Decimal.minExponent = go sign (divu10 (c+5)) (x+1) 
            -- nothing left to adjust, whether the values fit or not is decided by 'Decimal.pack'
            | sign < 0  = Decimal.pack (negate c) x
            | otherwise = Decimal.pack c x

--- find the representation of a 'Positive' with the smallest coefficient
--- This is done by moving trailing zeroes of the coefficient into the exponent.
--- NaN and 0 come out in their canonical forms
normRight ∷ Positive → Positive
normRight p
    | p.isNaN   = Positive.nan
    | p.isZero  = Positive.zero
    | otherwise = strip p.coefficient p.exponent
    where
        strip :: Long → Int → Positive
        strip c !e
            -- a trailing zero can go as long as the exponent is allowed to grow
            | e < Positive.maxExponent, remu10 c == 0 = strip (divu10 c) (e+1)
            | otherwise = Positive.pack c e

--- find the representation of a 'Positive' with the greatest coefficient
--- This is done by appending zeroes to the coefficient while lowering the exponent.
--- NaN and 0 come out in their canonical forms
normLeft ∷ Positive → Positive
normLeft p
    | p.isNaN   = Positive.nan
    | p.isZero  = Positive.zero
    | otherwise = grow p.coefficient p.exponent
    where
        grow :: Long → Int → Positive
        grow !c e = if e > Positive.minExponent && c <= coeffExpansionMax
                        then grow (mulu10 c) (e-1)      -- there is room for one more 0
                        else Positive.pack c e

--- represent a 'Decimal' with the greatest valid coefficient
--- The coefficient is multiplied by 10 as long as the product stays in the range
--- 'Decimal.minCoefficient' .. 'Decimal.maxCoefficient' and the exponent can be lowered.
--- NaN and 0 come out in their canonical forms
normLeftDecimal ∷ Decimal → Decimal
normLeftDecimal d
    | d.isNaN   = Decimal.nan
    | d.isZero  = Decimal.zero
    | otherwise = widen d.coefficient d.exponent
    where
        widen :: Long → Int → Decimal
        widen m x
            | x > Decimal.minExponent, fits = widen m10 (x-1)
            | otherwise                     = Decimal.pack m x
            where
                m10  = m*10
                fits = m10 >= Decimal.minCoefficient && m10 <= Decimal.maxCoefficient

--- represent a 'Positive' with the greatest valid coefficient
--- The coefficient is multiplied by 10 as long as the product does not exceed
--- 'Positive.maxCoefficient' and the exponent can be lowered.
--- NaN and 0 come out in their canonical forms
normLeftPositive ∷ Positive → Positive
normLeftPositive p
    | p.isNaN   = Positive.nan
    | p.isZero  = Positive.zero
    | otherwise = widen p.coefficient p.exponent
    where
        widen :: Long → Int → Positive
        widen m x
            | x > Positive.minExponent, m10 <= Positive.maxCoefficient = widen m10 (x-1)
            | otherwise = Positive.pack m x
            where m10 = mulu10 m

--- the 'Positive' with coefficient 1 and the exponent of the left normalized argument
ulpP ∷ Positive → Positive
ulpP p = Positive.pack 1 (normLeftPositive p).exponent
 
--- find the smallest number that makes a difference when adding it to the given number
--- This is 1 with the exponent the number has when its coefficient is made as big as possible.
ulp ∷ Decimal → Decimal
ulp d = Decimal.pack 1 (normLeftDecimal d).exponent

--- given a 'Decimal' x find a 'Decimal' y such that there is no 'Decimal' z with x < z < y
--- The result is x plus its 'ulp'.
nextUp ∷ Decimal → Decimal
nextUp d = d + step
    where step = ulp d

--- given a 'Decimal' x find a 'Decimal' y such that there is no 'Decimal' z with x > z > y
--- The result is x minus its 'ulp'.
nextDown ∷ Decimal → Decimal
nextDown d = d - step
    where step = ulp d

--- compare two 'Positive' numbers
--- Zeroes are dealt with first. Numbers with equal exponents are compared by their coefficients.
--- Otherwise both numbers are left normalized, and the exponents decide unless they are equal then.
compareP ∷ Positive → Positive → Ordering
compareP a b
    | a.isZero                   = if b.isZero then EQ else LT      -- there are no negative numbers
    | b.isZero                   = GT                               -- a is not 0 here
    | a.exponent  == b.exponent  = a.coefficient  <=> b.coefficient -- plain 'Long' comparison
    | na.exponent == nb.exponent = na.coefficient <=> nb.coefficient
    | otherwise                  = na.exponent    <=> nb.exponent
    where
        -- both operands with the greatest possible coefficient
        na = normLeft a
        nb = normLeft b


--- compare two 'Decimal' numbers
--- NaN is regarded as equal to NaN and as smaller than every number.
--- Hence, when only the second argument is NaN, the first one is the greater one.
--- (The result must be the mirror image of the one for swapped arguments.)
--- Negative numbers are smaller than non-negative ones; two negative numbers are
--- compared by their absolute values with the operands swapped.
compareD ∷ Decimal → Decimal → Ordering
compareD a b
            | a.isNaN    = if b.isNaN then EQ else LT
            | b.isNaN    = GT                   -- a is a number, and NaN sorts below all numbers
            | a.sign < 0 = if b.sign < 0 then compareP (positive b) (positive a) else LT
            | b.sign < 0 = GT
            | otherwise  = compareP (positive a) (positive b)

--- Ordering and equality of 'Decimal' numbers, based on 'compareD'.
--- The relational operators take care of NaN themselves.
instance Ord Decimal where
    (<=>) = compareD
    hashCode x = x.toBits.hashCode
    -- NaN is neither smaller nor greater than anything
    a < b   | a.isNaN || b.isNaN = false
            | otherwise          = (a <=> b) == LT
    a > b   | a.isNaN || b.isNaN = false
            | otherwise          = (a <=> b) == GT
    -- NaN goes with NaN only: the outcome is true for two NaN and false if just one operand is NaN
    a <= b  | a.isNaN   = b.isNaN
            | b.isNaN   = false
            | otherwise = (a <=> b) != GT
    a >= b  | a.isNaN   = b.isNaN
            | b.isNaN   = false
            | otherwise = (a <=> b) != LT
    a == b  | a.isNaN   = b.isNaN
            | b.isNaN   = false
            | otherwise = (a <=> b) == EQ
    a != b  = not (a == b)
                
                

--- Ordering and equality of 'Positive' numbers, based on 'compareP'.
--- The relational operators take care of NaN themselves.
instance Ord Positive where
    (<=>) = compareP
    hashCode x = x.toBits.hashCode
    -- NaN is neither smaller nor greater than anything
    a < b   | a.isNaN || b.isNaN = false
            | otherwise          = (a <=> b) == LT
    a > b   | a.isNaN || b.isNaN = false
            | otherwise          = (a <=> b) == GT
    -- NaN goes with NaN only: the outcome is true for two NaN and false if just one operand is NaN
    a <= b  | a.isNaN   = b.isNaN
            | b.isNaN   = false
            | otherwise = (a <=> b) != GT
    a >= b  | a.isNaN   = b.isNaN
            | b.isNaN   = false
            | otherwise = (a <=> b) != LT
    a == b  | a.isNaN   = b.isNaN
            | b.isNaN   = false
            | otherwise = (a <=> b) == EQ
    a != b  = not (a == b)


--- tell the number of significant digits
--- This counts the decimal digits of the coefficient after trailing zeroes have been
--- removed through 'normRight'. A coefficient of 0 has no significant digits.
sigDigits ∷ Positive → Int
sigDigits z = count 1 0
    where
        coeff = (normRight z).coefficient
        -- bound is always 10^digits
        count ∷ Long → Int → Int
        count bound digits
            | coeff < bound = digits
            | otherwise     = count (mulu10 bound) (digits+1)

--- Pack an arithmetic result by dividing the coefficient by 10 until it fits.
--- This will lose precision if the coefficient came out larger than 56 bits.
--- Exponents that are out of range are pulled towards the valid range as far as the coefficient permits.
packResult ∷ Long → Int → Positive
packResult !c !x
    -- traceLn ("packResult c=%d x=%d".format c x) = Positive.nan
    -- coefficient too big or exponent too small: drop the last digit (rounded), raise the exponent
    | c > Positive.maxCoefficient || x < Positive.minExponent
                = packResult (divu10 (c+5)) (x+1)
    -- exponent too big while there is room for another trailing 0
    | x > Positive.maxExponent && c <= coeffExpansionMax
                = packResult (c*10) (x-1)
    -- traceLn ("packResult result=" ++ show (Positive.pack c x)) = Positive.nan 
    | otherwise = Positive.pack c x

--- add 2 'Positive' numbers
--- both numbers must not be NaN
addP :: Positive → Positive → Positive
addP !a !b
    | a.isZero  = b
    | b.isZero  = a
    | ea == eb  = packResult (ca + cb) ea
    | ea >  eb  = align ca ea cb eb
    | otherwise = align cb eb ca ea
    where
        ca = a.coefficient
        ea = a.exponent
        cb = b.coefficient
        eb = b.exponent
        -- Make the exponents equal. The operand with the greater exponent must come first.
        -- As long as possible, the first coefficient is multiplied by 10. After that, the second
        -- coefficient is divided by 10 (with rounding), which may lose precision.
        -- (The digits lost are the ones that cannot be represented in the result.)
        -- This ends when the exponents match or when the second coefficient has become 0.
        align :: Long → Int → Long → Int → Positive
        align ac ax bc !bx
            --| traceLn ("addPgo a=%dE%d b=%dE%d".format ac ax bc bx) = Positive.nan
            | bc == 0                 = packResult ac ax
            | ax <= bx                = packResult (ac + bc) ax
            | ac <= coeffExpansionMax = align (mulu10 ac) (ax-1) bc bx          -- coeff times 10
            | otherwise               = align ac ax (divu10 (bc+5)) (bx+1)      -- lose precision on b

--- difference between 2 'Positive' numbers, where the first one must be >= the second one
--- both numbers must not be NaN
diffP :: Positive → Positive → Positive
diffP !a !b
    | b.isZero  = a      -- just in case ...
    | ea == eb  = Positive.pack (ca - cb) ea
    | ea >  eb  = align ca ea cb eb
    | otherwise = align cb eb ca ea
    where
        ca = a.coefficient
        ea = a.exponent
        cb = b.coefficient
        eb = b.exponent
        -- Make the exponents equal. The operand with the greater exponent must come first.
        -- As long as possible, the first coefficient is multiplied by 10. After that, the second
        -- coefficient is divided by 10 (with rounding), which may lose precision.
        -- (The digits lost are the ones that cannot be represented in the result.)
        -- This ends when the exponents match or when the second coefficient has become 0.
        -- The operands may arrive swapped here, hence the absolute value of the difference is taken.
        align :: Long → Int → Long → Int → Positive
        align ac ax bc !bx
            -- traceLn ("diffP.go a=%d,%d b=%d,%d".format ac ax bc bx) = Positive.nan 
            | bc == 0                 = packResult ac ax
            | ax <= bx                = packResult (abs (ac - bc)) ax
            | ac <= coeffExpansionMax = align (mulu10 ac) (ax-1) bc bx          -- coeff times 10
            | otherwise               = align ac ax (divu10 (bc+5)) (bx+1)      -- lose precision on b

--- binding for the Java method @java.lang.Long.numberOfLeadingZeros@
pure native leadingZeroes java.lang.Long.numberOfLeadingZeros :: Long -> Int
--- 64 minus the number of leading zero bits of the argument
--- ('mulP' uses this to tell whether a product fits in a 'Long')
relevantBits long = 64 - leadingZeroes long

--- multiply 2 'Positive' numbers
--- If one of the factors is 0, the result is 'Positive.zero'.
--- Otherwise both factors are right normalized and passed to the local worker with the smaller one first.
--- The worker multiplies directly if the product fits in a 'Long'. If not, it removes trailing zeroes
--- or splits one factor in its last digit and the rest, and adds the partial products.
--- The result is NaN if one of the partial products is NaN.
mulP !a !b 
    | a.isZero = Positive.zero
    | b.isZero = Positive.zero
    | otherwise = case (normRight a, normRight b) of
        (!a, !b) → if a <b 
            then mulP' a.coefficient a.exponent b.coefficient b.exponent
            else mulP' b.coefficient b.exponent a.coefficient a.exponent
    where
        -- arguments: coefficient and exponent of the first factor, coefficient and exponent of the second one
        mulP' ∷ Long → Int → Long → Int → Positive
        mulP' !ac !ax !bc !bx 
            -- traceLn ("mulP' a=%dE%d b=%dE%d".format ac ax bc bx) = Positive.nan
            -- fast multiplication
            -- (less than 64 relevant bits in total, so the product cannot overflow)
            | relevantBits ac + relevantBits bc < 64 = packResult (ac*bc) (ax+bx)
            -- try to make the coefficients as small as possible
            | remu10 ac == 0, ax < Positive.maxExponent = mulP' (divu10 ac) (ax+1) bc bx
            | remu10 bc == 0, bx < Positive.maxExponent = mulP' ac ax (divu10 bc) (bx+1)
            --| even ac, 
            --  even bc = case mulP' (ac >>> 1) ax (bc >>> 1) bx of 
            --                p | not p.isNaN → case addP p p of 
            --                        p2 | not p2.isNaN → addP p2 p2
            --                           | otherwise    → Positive.nan 
            --                  | otherwise → Positive.nan  
            --| even ac = case mulP' (ac >>> 1) ax bc bx of 
            --                p | not p.isNaN → addP p p
            --                  | otherwise   → Positive.nan
            --| even bc = case mulP' ac ax (bc >>> 1) bx of
            --                p | not p.isNaN → addP p p
            --                  | otherwise   → Positive.nan 
            --| traceLn ("mulP' cannot compute fast a=%dE%d b=%dE%d".format ac ax bc bx) = Positive.nan
            -- 456eX * 123eY = 456ex*3eY + 456eX*12e(Y+1)
            -- apply the distributive law, take care of multiplications that are NaN!
            -- p1 is the product with the last digit, pr the product with the remaining digits
            -- the factor with the smaller exponent is the one that gets split
            | ax < bx, !p1 ← mulP' bc bx (remu10 ac) ax, !pr ← mulP' bc bx (divu10 ac) (ax+1)
                        = if p1.isNaN || pr.isNaN then Positive.nan else addP pr p1 
            | !p1 ← mulP' ac ax (remu10 bc) bx, !pr ← mulP' ac ax (divu10 bc) (bx+1) 
                        = if p1.isNaN || pr.isNaN then Positive.nan else addP pr p1

--- compute the reciprocal of a 'Positive'
--- The result is NaN for 0 and NaN. For arguments that are at least 1e127, it is 1e-127.
--- Otherwise, the reciprocal of the right normalized coefficient is approximated by repeating
--- the step x1 = x0 * (2 - b*x0), starting with 1e-d where d is the number of significant digits.
--- The outcome of this is finally multiplied with 10 to the power of the negated exponent.
recipP ∷ Positive → Positive
recipP !a
    | a.isZero = Positive.nan
    | a.isNaN  = Positive.nan
    | a >= Positive.pack 1 127 = Positive.pack 1 (-127)   
    | otherwise = case sigDigits a of
        !d → case normRight a of
            !a → case go (Positive.pack a.coefficient 0) (Positive.pack 1 (-d))  of
                !r | r.isNaN   → r
                   | otherwise → mulP r (Positive.pack 1 (negate a.exponent))
    where        
        -- b is the number whose reciprocal is wanted, x0 the current approximation
        go ∷ Positive → Positive → Positive
        go b x0 
            -- traceLn ("b=" ++ show b ++ ", x0=" ++ show x0 ++ ", b*x0=" ++ show (mulP b x0)) = undefined
            -- mulP b x0 >= twoP = error "bad guess"
            | otherwise = case mulP x0 (diffP twoP (mulP b x0)) of
                -- x1 is the next approximation, and the value bound to one below is b*x1, which should be close to 1
                x1 →  case x1 `mulP` b of
                     !one 
                        --| traceLn ("b=" ++ show b 
                        --     ++ ", x0=" ++ show x0 ++ ", b*x0=" ++ show (mulP b x0) 
                        --     ++ ", x1=" ++ show x1 ++ ", b*x1=" ++ show one) = undefined
                        -- no change anymore
                        | x1 == x0  = case round17 x1 of !x1' → down b x1' (err b x1')
                        | x1 < x0, 
                          one == Positive.one = case round17 x1 of !x1' → down b x1' (err b x1')
                        -- the approximation should normally become greater, as 1e-Z is the lower bound 
                        | x1 < x0,
                          one <= Positive.one = case round17 x0 of !x0' →  down b x0' (err b x0')
                        | x1 < x0 = -- both x0 and x1 are too big, but x1 is better
                                    case round17 x1 of x1' →  down b x1' (err b x1')    -- avoid oscillation between 2 too great values  
                        -- the approximation has grown, go on with it
                        | otherwise = go b x1
            where
                -- If the coefficient is too big for a 'Decimal', drop the last digit and raise the exponent.
                -- The shortened coefficient is incremented by 1, unless the dropped digit was 0.
                round17 ∷ Positive → Positive
                round17 p = if p.coefficient > Decimal.maxCoefficient && p.exponent < Positive.maxExponent
                        then let c = if remu10 p.coefficient == 0 then 0 else 1
                                in Positive.pack (divu10 p.coefficient + c) (p.exponent + 1)
                        else p
                -- the absolute difference between b*x and 1
                err ∷ Positive → Positive → Positive
                err b x = case b `mulP` x of
                    o → if o > Positive.one then o `diffP` Positive.one else Positive.one `diffP` o
                -- Decrement the coefficient of x as long as this does not make the error e greater.
                down ∷ Positive → Positive → Positive → Positive
                down b x e
                         --| traceLn ("down " ++ show b ++ "  " ++ show x ++ "  " ++ show e) = Positive.nan
                         | x.coefficient > 1, 
                           x' ← Positive.pack (x.coefficient-1) x.exponent,
                           e' ← err b x', e' <= e = down b x' e' 
                         | otherwise = x
--- The 'Positive' 2, packed with coefficient 2 and exponent 0 ('recipP' makes use of it)
twoP ∷ Positive
!twoP = Positive.pack 2 0

--- Check 'reciprocal' for all whole numbers in the range given by the arguments.
--- Lists number, reciprocal and the product of both for each number
--- where the product differs from 1 by more than 2e-16
recipCheck from to =
    [ (n, r, prod) |
        n ← [from .. to],
        let d    = fromIntegral n,
        let r    = reciprocal d,
        let prod = d*r,
        abs (1-prod) > 2e-16z ]
        
        
--- Arithmetic for 'Decimal' numbers.
--- The work is done with the absolute values in form of 'Positive' numbers,
--- and 'decimal' puts the appropriate sign on the result.
instance Num Decimal where
    -- the negated coefficient is combined anew with the unchanged exponent bits
    negate d = Decimal.fromBits ((negate d.coefficient << 8) ¦ (d.exponent.long & 255))

    -- If one operand is NaN, this operand is the result.
    (+) :: Decimal → Decimal → Decimal
    !a + !b 
        | a.isNaN      = a
        | b.isNaN      = b
        | aNeg && bNeg = decimal (-1) (addP pa pb)
        | aNeg         = mixed pa pb
        | bNeg         = mixed pb pa
        | otherwise    = decimal 1 (addP pa pb)
        where
            aNeg = a.sign < 0
            bNeg = b.sign < 0
            pa   = positive a
            pb   = positive b
            -- Sum of a negative and a non-negative number.
            -- Both are given by their absolute values, the negative one comes first.
            mixed ∷ Positive → Positive → Decimal
            mixed neg pos
                | neg > pos = decimal (-1) (diffP neg pos)   -- (-5 + 3) = -(5-3)
                | otherwise = decimal 1    (diffP pos neg)   -- (-3 + 5) = +(5-3)

    -- subtraction is addition of the negated second operand
    !a - !b
        | a.isNaN   = a
        | b.isNaN   = b
        | otherwise = a + negate b
        
    -- a zero operand is checked for first, and this operand is the result then
    !a * !b 
        | a.isZero  = a
        | b.isZero  = b
        | a.isNaN   = a
        | b.isNaN   = b
        | otherwise = decimal (if a.sign == b.sign then 1 else (-1))
                              (mulP (positive a) (positive b))
        
    fromInt i     = Decimal.pack i.long 0
    fromInteger n = Decimal.pack (fromInteger n) 0

--- compute the reciprocal of a 'Decimal'
--- NaN and 1 are returned as they are, and the reciprocal of 0 is NaN.
--- Everything else is computed by 'recipP' from the absolute value, the result has the sign of the argument.
reciprocal ∷ Decimal → Decimal
reciprocal d 
    | d.isNaN || d == Decimal.one = d
    | d.isZero                    = Decimal.nan
    -- even d.coefficient = 0.5z * reciprocal (Decimal.pack (d.coefficient >> 1) d.exponent)
    | otherwise                   = decimal (sign d) (recipP (positive d)) 
    